#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Jun 12 13:00:18 2017

@author: luohao
"""

import numpy as np
import os
from keras import optimizers
from keras.utils import np_utils, generic_utils
from keras.models import Sequential, Model, load_model
from keras.layers import Dropout, Flatten, Dense, Input, GlobalAveragePooling2D, GlobalMaxPooling2D, concatenate
from keras.applications.resnet50 import ResNet50
from keras import backend as K
from keras.layers.core import Lambda
from scipy.misc import imread, imresize

import dao

IMAGE_SIZE = 224  # input edge length in px; matches the ResNet50 input used below

# Attribute keys to IGNORE when parsing a product's description dict:
# free-form / non-visual fields (brand, fabric, size, care notes, ...).
# Note: contains a few duplicates ('尺码', '品牌', '流行元素'); harmless for `in` tests.
not_header = ['商品名称', '店铺', '商品毛重', '商品编号', '上市时间', '货号', '商品产地',
              '表盘颜色', '适用年龄', '尺码', '适用年龄', '销售渠道类型', '材质成分', '品牌', '品牌名称',
              '上市年份季节', '年份季节', '配件/备注', '产地', '洗涤说明', '面料成分', '材质', '大码女装分类', '分类', '适用季节',
              '面料', '成分含量', '质地', '尺寸', '服装款式细节', '材质1', '自定义', '材质3', '深灰色', '黑色', '深蓝', '深蓝2.0',
              '温馨提示', '藏青色', '灰色材质', '成分', '材质2', '蓝/黑色材质', '白色材质', '纤维成份', '里料', '厚薄', '厚度',
              '上市年份/季节', '年份/季节', '街头', '甜美', '主要颜色', '规格', '罩杯材质', '中老年女装分类', '穿着方式', '尺码', '品牌',
              '通勤', '流行元素', '中老年风格', '服饰工艺', '女装质地', '适用场景', '网纱', '组合形式', '毛线粗细', '花朵', '流行元素',
              '工艺', '面料分类', '风格', '中老年女装图案', '适用对象', '克重', '组合规格', '里料分类', '主要材质', '图案',
              '填充物', '牛仔面料盎司', '功能', '设计裁剪', '人群', '图案文化', '流行元素/工艺', '流行元素/工艺', '弹力', '2.0材质',
              '适用场合', '重量', '类型', '弹性', '款式', '选购热点', '适用人群', '适合人群', '面料成份'
              ]

# One-hot group sizes: catalog 3, sleeve 6, collar 15, coat 6, dress 5, outseam 6, waist 5
catalog = [93, 94, 95]  # catalog IDs; original note listed 4 names (tops, skirts, trousers, specialty) for 3 ids — verify mapping
sleeve = ['长袖', '短袖', '无袖', '五分袖', '七分袖', '九分袖']  # sleeve length
# Superseded full collar vocabulary, kept for reference:
# collar = ['V领', '不规则领', '圆领', '低圆领', '其它', '高领', '半高领', '翻领', '连帽', '围巾领', 'A字型', '立领', '半高圆领',
#           '一字领', '方领', '旗袍领', 'POLO领', '堆堆领', '无领', '椭圆领', '半开领', 'U型领', '八字领', '高圆领', '低领', '常规领',
#           '娃娃领', '鸡心领', '平领', '双层领', '披肩领', '荡领', '橄榄领', '海军领', '大尖领', '系带领', '无', '荷叶领', '挂脖']  # collar type
collar = ['圆领', '其它', 'V领', '高领', '娃娃领', '立领', '方领', '翻领', '海军领', '半高领', 'POLO领', '连帽', '一字领', '半高圆领', '荷叶领']  # collar type
coat_length = ['超短款', '短款', '常规款', '中长款', '长款', '不对称衣长']  # coat length
dress_length = ['超短裙', '短裙', '中裙', '中长裙', '长裙']  # dress/skirt length
outseam = ['短裤', '长裤', '热裤', '五分裤', '七分裤', '九分裤']  # trouser length
waistline = ['低腰', '中腰', '高腰', '松紧腰', '调节腰']  # waist type


def get_good_ids(base_dir):
    """Return the names of all sub-directories of base_dir (one per product id)."""
    entries = os.listdir(base_dir)
    return [entry for entry in entries if os.path.isdir(os.path.join(base_dir, entry))]


class ImageClass():
    """Container tying one product (good) id to its image paths and label vectors."""

    def __init__(self, good_id, image_paths, catalog, attr):
        # catalog: one-hot catalog vector; attr: (1, N) concatenated attribute vector
        self.good_id = good_id
        self.image_paths = image_paths
        self.catalog = catalog
        self.attr = attr

    def __str__(self):
        count = len(self.image_paths)
        return self.good_id + ', ' + str(count) + ' images,'

    def __len__(self):
        # Number of images belonging to this product.
        return len(self.image_paths)


def get_img_info(base_dir):
    """Collect image paths and label vectors for every product under base_dir.

    For each product id (sub-directory name), the catalog id and attribute
    description are fetched from the t_product table; each recognised
    attribute group is one-hot encoded and the groups are concatenated into a
    single (1, N) attribute vector.

    Returns:
        list[ImageClass]: one entry per product with a DB row whose
        description parsed successfully; others are skipped (parse failures
        are reported to stdout).
    """
    images = []

    good_ids = get_good_ids(base_dir)
    for good_id in good_ids:
        img_attr = []
        good_dir = os.path.join(base_dir, good_id)  # renamed from `dir`, which shadowed the builtin

        # One-hot buffers, one per attribute group.
        sleeve_categorical = np.zeros(len(sleeve))              # sleeve length
        collar_categorical = np.zeros(len(collar))              # collar type
        coat_length_categorical = np.zeros(len(coat_length))    # coat length
        dress_length_categorical = np.zeros(len(dress_length))  # dress/skirt length
        outseam_categorical = np.zeros(len(outseam))            # trouser length
        waistline_categorical = np.zeros(len(waistline))        # waist type

        # All image files belonging to this product.
        image_paths = [os.path.join(good_dir, name) for name in os.listdir(good_dir)]
        str_sql = "select catalogID,description from t_product where sourceProductID='%s'" % good_id
        result = dao.select_one(str_sql)
        if not result:
            continue
        # Product catalog, one-hot over the known catalog ids.
        catalog_id = int(result['catalogID'])
        catalog_categorical = np_utils.to_categorical(catalog.index(catalog_id), len(catalog))
        # Product attributes.
        key = ''  # BUG FIX: `key` was unbound in the except handler when eval() itself raised
        try:
            # SECURITY NOTE: eval() on DB content executes arbitrary code; if
            # descriptions are plain dict literals, ast.literal_eval is safer.
            attrs = eval(result['description'])
            for key in attrs.keys():
                if key in not_header:
                    continue
                attr = attrs.get(key)
                if key == '袖长':  # sleeve length
                    if attr == '中袖':
                        attr = '五分袖'  # normalise synonym before the index lookup
                    if not catalog_id == 95:  # trousers (95) carry no sleeve attribute
                        sleeve_categorical[sleeve.index(attr)] = 1.
                elif key == '领型':  # collar type
                    collar_categorical[collar.index(attr)] = 1.
                elif key == '衣长':  # coat length
                    coat_length_categorical[coat_length.index(attr)] = 1.
                elif key == '裙长':  # dress length
                    dress_length_categorical[dress_length.index(attr)] = 1.
                elif key == '裤长':  # trouser length, only for catalog 95
                    if catalog_id == 95:
                        outseam_categorical[outseam.index(attr)] = 1.
                elif key == '腰型':  # waist type, only for catalog 95
                    if catalog_id == 95:
                        waistline_categorical[waistline.index(attr)] = 1.
        except Exception as e:
            # Unknown vocabulary values raise ValueError from .index(); bad
            # descriptions raise from eval(). Either way: report and skip.
            print('catalog_id:' + str(catalog_id) + ',attr eval error:' + key + ',good_id:' + good_id)
            print(e)
            continue
        # Concatenate all attribute groups into one flat label vector.
        img_attr.extend(sleeve_categorical)
        img_attr.extend(collar_categorical)
        img_attr.extend(coat_length_categorical)
        img_attr.extend(dress_length_categorical)
        img_attr.extend(outseam_categorical)
        img_attr.extend(waistline_categorical)
        images.append(
            ImageClass(good_id, image_paths, catalog_categorical, np.array(img_attr).reshape(1, len(img_attr))))
    return images


def triplet_loss(y_true, y_pred):
    """Margin-based triplet loss (margin 0.2) over a batch laid out as
    [anchors | positives | negatives], each segment of size PN (module global).
    Embeddings are L2-normalised before Euclidean distances are taken.
    """
    global PN
    margin = 0.2
    embeddings = K.l2_normalize(y_pred, axis=1)
    anchor = embeddings[0:PN, :]
    positive = embeddings[PN:2 * PN, :]
    negative = embeddings[2 * PN:3 * PN, :]
    pos_dist = K.sqrt(K.sum(K.square(anchor - positive), axis=1, keepdims=True))
    neg_dist = K.sqrt(K.sum(K.square(anchor - negative), axis=1, keepdims=True))
    return K.mean(K.maximum(0.0, pos_dist - neg_dist + margin))


def get_triplet_data(images_info, PN):
    """Assemble one triplet batch of PN (anchor, positive, negative) images.

    Row layout matches the slicing in triplet_loss: rows [0, PN) are anchors,
    [PN, 2*PN) positives, [2*PN, 3*PN) negatives.

    Args:
        images_info: list of ImageClass entries (one per product/class).
        PN: number of triplets to draw; requires len(images_info) >= PN.

    Returns:
        (images, labels): arrays of shape (3*PN, IMAGE_SIZE, IMAGE_SIZE, 3)
        and (3*PN, attr_dim).
    """
    nrof_classes = len(images_info)
    class_indices = np.arange(nrof_classes)
    np.random.shuffle(class_indices)

    images = np.zeros((3 * PN, IMAGE_SIZE, IMAGE_SIZE, 3))
    labels = np.zeros((3 * PN, len(images_info[0].attr[0])))

    for i in range(PN):
        class_index = class_indices[i]
        nrof_images_in_class = len(images_info[class_index])
        image_indices = np.arange(nrof_images_in_class)
        np.random.shuffle(image_indices)
        # a,p: the anchor is a random image of the class; the positive defaults
        # to the same image and is replaced by the first 'desc' image found.
        images[i, :, :, :] = get_img_array(images_info[class_index].image_paths[image_indices[0]])
        labels[i] = images_info[class_index].attr[0]
        images[i + PN, :, :, :] = get_img_array(images_info[class_index].image_paths[image_indices[0]])
        for ind in range(1, nrof_images_in_class):  # image_indices images_info[class_index].image_paths:
            path = images_info[class_index].image_paths[image_indices[ind]]
            if 'desc' in path:
                images[i + PN, :, :, :] = get_img_array(path)
                break
        # if nrof_images_in_class > 1:
        #     images[i + PN, :, :, :] = get_img_array(images_info[class_index].image_paths[image_indices[1]])
        # else:
        #     images[i + PN, :, :, :] = get_img_array(images_info[class_index].image_paths[image_indices[0]])
        labels[i + PN] = images_info[class_index].attr[0]

        # neg: draw uniformly from [0, nrof_classes-2] and bump by one on
        # collision so the negative class never equals the anchor class.
        n_class_index = np.random.randint(nrof_classes - 1)
        if n_class_index == class_index:
            n_class_index += 1
        # NOTE(review): if the negative class has no path containing 'desc',
        # its image row stays all zeros — confirm this fallback is intended.
        for path in images_info[n_class_index].image_paths:
            if 'desc' in path:
                images[i + PN * 2, :, :, :] = get_img_array(path)
                break
        # images[i + PN * 2, :, :, :] = get_img_array(
        #     images_info[n_class_index].image_paths[0])  # dataset[n_class_index].images[0]
        labels[i + PN * 2] = images_info[n_class_index].attr[0]
    return images, labels


def get_coupled_images(images_info, batch_num=30):
    """Sample a coupled batch: `batch_num` images from ONE random positive
    class, paired with one random image from `batch_num` other classes.

    Requires len(images_info) > batch_num (class_indices[i + 1] is used for
    the negatives). batch_num is clamped to the positive class's image count.

    Returns:
        (p_images, n_images, p_labels, n_labels): path lists plus label
        matrices of shape (batch_num, len(catalog) + attr_dim), rows laid out
        as [catalog one-hot | attribute vector].
    """
    p_images = []
    n_images = []
    nrof_classes = len(images_info)
    class_indices = np.arange(nrof_classes)
    np.random.shuffle(class_indices)
    p_images_info = images_info[class_indices[0]]
    if len(p_images_info.image_paths) < batch_num:
        batch_num = len(p_images_info.image_paths)
    label_dim = len(p_images_info.catalog) + len(p_images_info.attr[0])
    # BUG FIX: `p_labels = n_labels = np.zeros(...)` bound both names to the
    # SAME array, so the negative labels silently overwrote the positive ones.
    # Allocate two independent arrays.
    p_labels = np.zeros((batch_num, label_dim))
    n_labels = np.zeros((batch_num, label_dim))
    p_class_indices = np.arange(batch_num)
    for i in range(batch_num):
        p_images.append(p_images_info.image_paths[p_class_indices[i]])
        p_labels[i][0:len(p_images_info.catalog)] = p_images_info.catalog
        p_labels[i][len(p_images_info.catalog):] = p_images_info.attr[0]

        n_images_info = images_info[class_indices[i + 1]]
        rand_int = np.random.randint(len(n_images_info.image_paths))
        n_images.append(n_images_info.image_paths[rand_int])
        n_labels[i][0:len(n_images_info.catalog)] = n_images_info.catalog
        n_labels[i][len(n_images_info.catalog):] = n_images_info.attr[0]

    return p_images, n_images, p_labels, n_labels


def get_random_batch(images_info, batch_num=50):
    """Return `batch_num` distinct randomly chosen entries of images_info.

    If images_info has at most batch_num entries, the whole list is returned
    unchanged.

    BUG FIX: the original shuffled `class_indices` but then indexed
    `images_info[i]`, so it always returned the FIRST batch_num items; the
    shuffled indices are now actually used.
    """
    nrof_classes = len(images_info)
    if nrof_classes <= batch_num:
        return images_info
    class_indices = np.arange(nrof_classes)
    np.random.shuffle(class_indices)
    return [images_info[class_indices[i]] for i in range(batch_num)]


def get_img_array(img_path):
    """Load the image at img_path and resize it to IMAGE_SIZE x IMAGE_SIZE x 3.

    BUG FIX: the original called `imresize(IMAGE_SIZE, IMAGE_SIZE, 3)`, which
    discards the loaded image (wrong arguments entirely); the loaded array
    must be resized instead.
    """
    img = imread(img_path)
    if img.shape != (IMAGE_SIZE, IMAGE_SIZE, 3):
        img = imresize(img, (IMAGE_SIZE, IMAGE_SIZE))
    # NOTE(review): normalization (img / 255.) is commented out — confirm the
    # model is meant to consume raw [0, 255] pixels.
    # img = img / 255.
    return img


def coupled_cluster_loss(p_encoded, n_encoded):
    # Coupled-cluster style loss: pull positives toward their centre, push the
    # nearest negative at least `arf` further away.
    p_encoded = K.l2_normalize(p_encoded)
    n_encoded = K.l2_normalize(n_encoded)
    p_center = K.mean(p_encoded)  # axis=0: column mean; axis=1: row mean; None: mean over all elements
    # NOTE(review): with no axis, p_center is a single scalar, not a
    # per-dimension centroid — confirm K.mean(p_encoded, axis=0) wasn't intended.
    # NOTE(review): argmin over POSITIVE distances is used to index the
    # NEGATIVE batch; presumably the negative nearest to the centre was
    # intended — verify before use.
    min_neg_to_cp = K.argmin(K.square(p_encoded - p_center))
    min_encoded = n_encoded[min_neg_to_cp]
    arf = 0.2  # margin
    return K.maximum(0, K.square(p_encoded - p_center) + arf - K.square(min_encoded - p_center)) / 2.


def coupledClusterModel(batch_size):
    """Build the coupled-cluster model: ResNet50 features for a positive and a
    negative image batch, with catalog and attribute prediction heads.

    NOTE(review): this construction looks unfinished/broken — confirm intent:
      * predict_on_batch() expects numpy arrays, not the symbolic Input
        tensors passed here; a shared encoder is normally applied as
        resnet_model(p_inputs).
      * Input(shape=(batch_size, H, W, 3)) stacks batch_size on top of Keras'
        implicit batch dimension, giving a 5-D input.
      * `loss` is a plain tensor, not a layer output, so it cannot be listed
        in Model(outputs=...).
      * fc_attr has 60 units, but the attribute vector built in get_img_info
        has 43 dims (6+15+6+5+6+5) — verify the label width.
    """
    p_inputs = Input(shape=(batch_size,IMAGE_SIZE, IMAGE_SIZE, 3))
    n_inputs = Input(shape=(batch_size,IMAGE_SIZE, IMAGE_SIZE, 3))
    # labels = np.ones((batch_size, 1))

    resnet_model = ResNet50(weights='imagenet', include_top=False, pooling='max',input_tensor=Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3)))
    # p_encoded = resnet_model.output
    p_encoded = resnet_model.predict_on_batch(p_inputs)
    n_encoded = resnet_model.predict_on_batch(n_inputs)
    cata_X = concatenate([p_encoded, n_encoded])
    # cata_X = GlobalMaxPooling2D()(cata_X)
    cata_X = Dense(512, activation='relu')(cata_X)
    # 3-way catalog softmax head and multi-label sigmoid attribute head.
    cata_pred = Dense(3, activation='softmax', name='fc_catalog')(cata_X)
    attr_pred = Dense(60, activation='sigmoid', name='fc_attr')(cata_X)
    loss = coupled_cluster_loss(p_encoded, n_encoded)
    model = Model(inputs=[p_inputs, n_inputs], outputs=[cata_pred, attr_pred, loss])
    return model


if __name__ == '__main__':
    # Sanity check: size of each one-hot attribute group
    # (sleeve, collar, coat length, dress length, trouser length, waist).
    print(len(sleeve), len(collar), len(coat_length), len(dress_length), len(outseam), len(waistline))
    # model = coupledClusterModel(5)
    # model.compile(optimizer='adam',
    #               loss=['categorical_crossentropy', 'binary_crossentropy', lambda y_true, y_pred: y_pred],
    #               loss_weights=[1.0, 1.0, 1.0], metrics={'fc_catalog': 'accuracy', 'fc_attr': 'accuracy'})

    # Load product metadata and sample one coupled positive/negative batch.
    all_images = get_img_info('E:/cdn/datasets/triplet')
    p_images, n_images, p_labels, n_labels = get_coupled_images(all_images, 20)
    print(p_labels.shape, n_labels.shape)
    for idx, img_path in enumerate(p_images):
        print(idx, img_path)
        print(idx, p_labels[idx])
    # model.fit()
